# -*- coding:utf-8 -*-

import tensorflow as tf
from config.glob.global_pool import global_pool

"""
常规损失函数和自定义损失函数
"""


def common_loss(ys=None, y_pred=None):
    """
    Build one of the built-in (non-custom) losses selected by configuration.

    The loss is chosen by ``global_pool.config.loss.name``. For the
    ``tf.losses.*`` variants ('softmax_cross_entropy', 'hinge_loss') the loss
    is reduced according to ``global_pool.config.loss.reduction``, which must
    be 'sum' or 'mean'.

    :param ys: label tensor (one-hot labels for the softmax variants)
    :param y_pred: logits tensor
    :return: loss tensor. NOTE: 'softmax_cross_entropy_with_logits' returns the
        unreduced per-example loss and ignores the reduction setting; the other
        variants return a loss reduced with the configured reduction.
    :raises ValueError: if the configured loss name is not supported, or if a
        reduction-based loss is configured with an unsupported reduction key
    """
    reduction = {
        'sum': tf.losses.Reduction.SUM,
        'mean': tf.losses.Reduction.MEAN,
    }
    cfg_loss = global_pool.config.loss
    name = cfg_loss.name

    if name == 'softmax_cross_entropy_with_logits':
        # tf.nn op: no reduction argument, so the per-example loss is returned.
        return tf.nn.softmax_cross_entropy_with_logits(labels=ys, logits=y_pred)

    if name not in ('softmax_cross_entropy', 'hinge_loss'):
        raise ValueError('不支持该类型损失函数: {}'.format(name))

    # Validate the reduction key explicitly: an unknown key would otherwise be
    # forwarded to TF as None and either be silently treated as a plain sum or
    # fail later with an opaque error.
    redu = reduction.get(cfg_loss.reduction)
    if redu is None:
        raise ValueError(
            'Unsupported loss reduction {!r}; expected one of {}'.format(
                cfg_loss.reduction, sorted(reduction)))

    if name == 'softmax_cross_entropy':
        return tf.losses.softmax_cross_entropy(onehot_labels=ys, logits=y_pred, reduction=redu)
    return tf.losses.hinge_loss(labels=ys, logits=y_pred, reduction=redu)


def text_loss(ys, y_pred):
    """
    Custom text loss: mean softmax cross-entropy against one-hot labels.

    ``ys`` is one-hot encoded over ``global_pool.embedding.vocab_size`` classes,
    reshaped to the shape of ``y_pred``, and compared with the logits via
    softmax cross-entropy. The per-example losses are averaged into a scalar.

    :param ys: integer label tensor (presumably token ids — TODO confirm); its
        one-hot encoding must contain as many elements as ``y_pred``
    :param y_pred: logits tensor, presumably with last dimension vocab_size —
        TODO confirm against the model
    :return: scalar mean cross-entropy loss tensor
    """
    # TODO: depth comes from the embedding config; confirm it always matches the
    # last dimension of the logits.
    ys_one_hot = tf.one_hot(ys, global_pool.embedding.vocab_size)
    # Reshape with the dynamic shape: y_pred.get_shape() is only partially known
    # when the batch dimension is None, and tf.reshape cannot convert a partially
    # known TensorShape into a shape tensor.
    ys_reshaped = tf.reshape(ys_one_hot, tf.shape(y_pred))
    loss = tf.nn.softmax_cross_entropy_with_logits(logits=y_pred, labels=ys_reshaped)
    return tf.reduce_mean(loss)


def get_loss(ys, y_pred):
    """
    Entry point that builds the loss tensor.

    Uses the custom ``text_loss`` when ``global_pool.config.loss.self_losser``
    is truthy; otherwise falls back to the configured built-in ``common_loss``.

    :param ys: label tensor, forwarded unchanged to the selected loss function
    :param y_pred: prediction/logits tensor, forwarded unchanged
    :return: whatever the selected loss function returns
    """
    use_custom_loss = global_pool.config.loss.self_losser
    loss_fn = text_loss if use_custom_loss else common_loss
    return loss_fn(ys, y_pred)
